{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "djUvWu41mtXa"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "su2RaORHpReL"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NztQK2uFpXT-"
      },
      "source": [
        "# 在 TensorBoard 中显示图像数据\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://tensorflow.google.cn/tensorboard/image_summaries\"><img src=\"https://tensorflow.google.cn/images/tf_logo_32px.png\">在 TensorFlow.org 上查看</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/zh-cn/tensorboard/image_summaries.ipynb\"><img src=\"https://tensorflow.google.cn/images/colab_logo_32px.png\">在 Google Colab 中运行</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/zh-cn/tensorboard/image_summaries.ipynb\"><img src=\"https://tensorflow.google.cn/images/GitHub-Mark-32px.png\">在 GitHub 上查看源代码</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eDXRFe_qp5C3"
      },
      "source": [
        "## 概述\n",
        "\n",
        "使用 **TensorFlow Image Summary API**，您可以轻松地在 TensorBoard 中记录张量和任意图像并进行查看。这在采样和检查输入数据，或[可视化层权重](http://cs231n.github.io/understanding-cnn/)和[生成的张量](https://hub.packtpub.com/generative-adversarial-networks-using-keras/)方面非常实用。您还可以将诊断数据记录为图像，这在模型开发过程中可能会有所帮助。\n",
        "\n",
        "在本教程中，您将了解如何使用 Image Summary API 将张量可视化为图像。您还将了解如何获取任意图像，将其转换为张量并在 TensorBoard 中进行可视化。教程将通过一个简单而真实的示例，向您展示使用图像摘要了解模型性能。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dG-nnZK9qW9z"
      },
      "source": [
        "## 设置"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "3U5gdCw_nSG3"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "TensorFlow 2.x selected.\n"
          ]
        }
      ],
      "source": [
        "try:\n",
        "  # %tensorflow_version only exists in Colab.\n",
        "  %tensorflow_version 2.x\n",
        "except Exception:\n",
        "  pass\n",
        "\n",
        "# Load the TensorBoard notebook extension.\n",
        "%load_ext tensorboard"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "1qIKtOBrqc9Y"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "TensorFlow version:  2.2\n"
          ]
        }
      ],
      "source": [
        "from datetime import datetime\n",
        "import io\n",
        "import itertools\n",
        "from packaging import version\n",
        "from six.moves import range\n",
        "\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import sklearn.metrics\n",
        "\n",
        "print(\"TensorFlow version: \", tf.__version__)\n",
        "assert version.parse(tf.__version__).release[0] >= 2, \\\n",
        "    \"This notebook requires TensorFlow 2.0 or above.\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tq0gyXOGZ3-h"
      },
      "source": [
        "# 下载 Fashion-MNIST 数据集\n",
        "\n",
        "您将构造一个简单的神经网络，用于对 [Fashion-MNIST](https://research.zalando.com/welcome/mission/research-projects/fashion-mnist/) 数据集中的图像进行分类。此数据集包含 70,000 个 28x28 灰度时装产品图像，来自 10 个类别，每个类别 7,000 个图像。\n",
        "\n",
        "首先，下载数据："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "VmEQwCon3i7m"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz\n",
            "32768/29515 [=================================] - 0s 0us/step\n",
            "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz\n",
            "26427392/26421880 [==============================] - 0s 0us/step\n",
            "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz\n",
            "8192/5148 [===============================================] - 0s 0us/step\n",
            "Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz\n",
            "4423680/4422102 [==============================] - 0s 0us/step\n"
          ]
        }
      ],
      "source": [
        "# Download the data. The data is already divided into train and test.\n",
        "# The labels are integers representing classes.\n",
        "fashion_mnist = keras.datasets.fashion_mnist\n",
        "(train_images, train_labels), (test_images, test_labels) = \\\n",
        "    fashion_mnist.load_data()\n",
        "\n",
        "# Names of the integer classes, i.e., 0 -> T-short/top, 1 -> Trouser, etc.\n",
        "class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', \n",
        "    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qNsjMY0364j4"
      },
      "source": [
        "## 可视化单个图像\n",
        "\n",
        "为了解 Image Summary API 的工作原理，现在您将在 TensorBoard 中记录训练集中的第一个训练图像。\n",
        "\n",
        "在此之前，请检查训练数据的形状："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "FxMPcdmvBn9t"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Shape:  (28, 28)\n",
            "Label:  9 -> Ankle boot\n"
          ]
        }
      ],
      "source": [
        "print(\"Shape: \", train_images[0].shape)\n",
        "print(\"Label: \", train_labels[0], \"->\", class_names[train_labels[0]])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4F8zbUKfBuUt"
      },
      "source": [
        "请注意，数据集中每个图像的形状均为 2 秩张量，形状为 (28, 28)，分别表示高度和宽度。\n",
        "\n",
        "但是，`tf.summary.image()` 需要一个包含 `(batch_size, height, width, channels)` 的 4 秩张量。因此，需要重塑张量。\n",
        "\n",
        "您仅记录一个图像，因此 `batch_size` 为 1。图像为灰度图，因此将 `channels` 设置为 1。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5yPh-7EWB8IK"
      },
      "outputs": [],
      "source": [
        "# Reshape the image for the Summary API.\n",
        "img = np.reshape(train_images[0], (-1, 28, 28, 1))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JAdJDY3FCCwt"
      },
      "source": [
        "现在，您可以在 TensorBoard 中记录此图像并进行查看了。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IJNpyVyxbVtT"
      },
      "outputs": [],
      "source": [
        "# Clear out any prior log data.\n",
        "!rm -rf logs\n",
        "\n",
        "# Sets up a timestamped log directory.\n",
        "logdir = \"logs/train_data/\" + datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
        "# Creates a file writer for the log directory.\n",
        "file_writer = tf.summary.create_file_writer(logdir)\n",
        "\n",
        "# Using the file writer, log the reshaped image.\n",
        "with file_writer.as_default():\n",
        "  tf.summary.image(\"Training data\", img, step=0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rngALbRogXe6"
      },
      "source": [
        "现在，使用 TensorBoard 检查图像。等待几秒，直至界面出现。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T_X-wIy-lD9f"
      },
      "outputs": [],
      "source": [
        "%tensorboard --logdir logs/train_data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c8n8YqGlT3-c"
      },
      "source": [
        " <img src=\"https://github.com/tensorflow/tensorboard/blob/master/docs/images/images_single.png?raw=1\" class=\"\"> "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "34enxJjjgWi7"
      },
      "source": [
        "“Images”标签页显示您刚刚记录的图像。这是一只“短靴”。\n",
        "\n",
        "图像会缩放到默认大小，以方便查看。如果要查看未缩放的原始图像，请选中左上方的“Show actual image size”。\n",
        "\n",
        "调整亮度和对比度滑块，查看它们如何影响图像像素。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bjACE1lAsqUd"
      },
      "source": [
        "## 可视化多个图像\n",
        "\n",
        "记录一个张量非常简单，但是要记录多个训练样本应如何操作？\n",
        "\n",
        "只需在向 `tf.summary.image()` 传递数据时指定要记录的图像数即可。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iHUjCXbetIpb"
      },
      "outputs": [],
      "source": [
        "with file_writer.as_default():\n",
        "  # Don't forget to reshape.\n",
        "  images = np.reshape(train_images[0:25], (-1, 28, 28, 1))\n",
        "  tf.summary.image(\"25 training data examples\", images, max_outputs=25, step=0)\n",
        "\n",
        "%tensorboard --logdir logs/train_data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fr6LFQG9UD6z"
      },
      "source": [
        " <img src=\"https://github.com/tensorflow/tensorboard/blob/master/docs/images/images_multiple.png?raw=1\"> "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c-7sZs3XuBBy"
      },
      "source": [
        "## 记录任意图像数据\n",
        "\n",
        "如果要可视化的图像并非张量（例如 [matplotlib](https://matplotlib.org/) 生成的图像），应如何操作？\n",
        "\n",
        "您需要一些样板代码来将图转换为张量，随后便可继续处理。\n",
        "\n",
        "在以下代码中，您将使用 matplotlib 的 `subplot()` 函数以美观的网格结构记录前 25 个图像。随后，您将在 TensorBoard 中查看该网格："
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "F5U_5WKt8bdQ"
      },
      "outputs": [],
      "source": [
        "# Clear out prior logging data.\n",
        "!rm -rf logs/plots\n",
        "\n",
        "logdir = \"logs/plots/\" + datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
        "file_writer = tf.summary.create_file_writer(logdir)\n",
        "\n",
        "def plot_to_image(figure):\n",
        "  \"\"\"Converts the matplotlib plot specified by 'figure' to a PNG image and\n",
        "  returns it. The supplied figure is closed and inaccessible after this call.\"\"\"\n",
        "  # Save the plot to a PNG in memory.\n",
        "  buf = io.BytesIO()\n",
        "  plt.savefig(buf, format='png')\n",
        "  # Closing the figure prevents it from being displayed directly inside\n",
        "  # the notebook.\n",
        "  plt.close(figure)\n",
        "  buf.seek(0)\n",
        "  # Convert PNG buffer to TF image\n",
        "  image = tf.image.decode_png(buf.getvalue(), channels=4)\n",
        "  # Add the batch dimension\n",
        "  image = tf.expand_dims(image, 0)\n",
        "  return image\n",
        "\n",
        "def image_grid():\n",
        "  \"\"\"Return a 5x5 grid of the MNIST images as a matplotlib figure.\"\"\"\n",
        "  # Create a figure to contain the plot.\n",
        "  figure = plt.figure(figsize=(10,10))\n",
        "  for i in range(25):\n",
        "    # Start next subplot.\n",
        "    plt.subplot(5, 5, i + 1, title=class_names[train_labels[i]])\n",
        "    plt.xticks([])\n",
        "    plt.yticks([])\n",
        "    plt.grid(False)\n",
        "    plt.imshow(train_images[i], cmap=plt.cm.binary)\n",
        "  \n",
        "  return figure\n",
        "\n",
        "# Prepare the plot\n",
        "figure = image_grid()\n",
        "# Convert to image and log\n",
        "with file_writer.as_default():\n",
        "  tf.summary.image(\"Training data\", plot_to_image(figure), step=0)\n",
        "\n",
        "%tensorboard --logdir logs/plots"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o_tIghRsXY7S"
      },
      "source": [
        " <img src=\"https://github.com/tensorflow/tensorboard/blob/master/docs/images/images_multiple.png?raw=1\"> "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vZx70BC1zhgW"
      },
      "source": [
        "## 构建图像分类器\n",
        "\n",
        "现在，让我们将其运用于实例当中。毕竟，我们是在研究机器学习，而不是绘制漂亮的图片！\n",
        "\n",
        "您将使用图像摘要来了解模型性能，同时为 Fashion-MNIST 数据集训练一个简单的分类器。\n",
        "\n",
        "首先，创建一个非常简单的模型并通过设置优化器和损失函数对该模型进行编译。在编译步骤中，还需指定您要定期记录其准确率的分类器。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "R74hPWJHzgvZ"
      },
      "outputs": [],
      "source": [
        "model = keras.models.Sequential([\n",
        "    keras.layers.Flatten(input_shape=(28, 28)),\n",
        "    keras.layers.Dense(32, activation='relu'),\n",
        "    keras.layers.Dense(10, activation='softmax')\n",
        "])\n",
        "\n",
        "model.compile(\n",
        "    optimizer='adam', \n",
        "    loss='sparse_categorical_crossentropy',\n",
        "    metrics=['accuracy']\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SdT_PpZB1UMn"
      },
      "source": [
        "训练分类器时，查看[混淆矩阵](https://en.wikipedia.org/wiki/Confusion_matrix)非常实用。混淆矩阵可帮助您详细了解分类器在测试数据上的性能。\n",
        "\n",
        "定义一个计算混淆矩阵的函数。您将使用便捷的 [Scikit-learn](https://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html) 函数进行定义，然后使用 matplotlib 绘制混淆矩阵。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rBiXP8-UO8t6"
      },
      "outputs": [],
      "source": [
        "def plot_confusion_matrix(cm, class_names):\n",
        "  \"\"\"\n",
        "  Returns a matplotlib figure containing the plotted confusion matrix.\n",
        "\n",
        "  Args:\n",
        "    cm (array, shape = [n, n]): a confusion matrix of integer classes\n",
        "    class_names (array, shape = [n]): String names of the integer classes\n",
        "  \"\"\"\n",
        "  figure = plt.figure(figsize=(8, 8))\n",
        "  plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)\n",
        "  plt.title(\"Confusion matrix\")\n",
        "  plt.colorbar()\n",
        "  tick_marks = np.arange(len(class_names))\n",
        "  plt.xticks(tick_marks, class_names, rotation=45)\n",
        "  plt.yticks(tick_marks, class_names)\n",
        "\n",
        "  # Normalize the confusion matrix.\n",
        "  cm = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=2)\n",
        "\n",
        "  # Use white text if squares are dark; otherwise black.\n",
        "  threshold = cm.max() / 2.\n",
        "  for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):\n",
        "    color = \"white\" if cm[i, j] > threshold else \"black\"\n",
        "    plt.text(j, i, cm[i, j], horizontalalignment=\"center\", color=color)\n",
        "\n",
        "  plt.tight_layout()\n",
        "  plt.ylabel('True label')\n",
        "  plt.xlabel('Predicted label')\n",
        "  return figure"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6lOAl_v26QGq"
      },
      "source": [
        "现在，您可以训练分类器并定期记录混淆矩阵了。\n",
        "\n",
        "您将执行以下操作：\n",
        "\n",
        "1. 创建 [Keras TensorBoard 回调](https://tensorflow.google.cn/api_docs/python/tf/keras/callbacks/TensorBoard)以记录基本指标\n",
        "2. 创建 [Keras LambdaCallback](https://tensorflow.google.cn/api_docs/python/tf/keras/callbacks/LambdaCallback) 以在每个周期结束时记录混淆矩阵\n",
        "3. 使用 Model.fit() 训练模型，确保传递两个回调\n",
        "\n",
        "随着训练的进行，向下滚动以查看 TensorBoard 启动情况。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "utd-vH6hn5RY"
      },
      "outputs": [],
      "source": [
        "# Clear out prior logging data.\n",
        "!rm -rf logs/image\n",
        "\n",
        "logdir = \"logs/image/\" + datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
        "# Define the basic TensorBoard callback.\n",
        "tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)\n",
        "file_writer_cm = tf.summary.create_file_writer(logdir + '/cm')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bXQ7-9CF0TPA"
      },
      "outputs": [],
      "source": [
        "def log_confusion_matrix(epoch, logs):\n",
        "  # Use the model to predict the values from the validation dataset.\n",
        "  test_pred_raw = model.predict(test_images)\n",
        "  test_pred = np.argmax(test_pred_raw, axis=1)\n",
        "\n",
        "  # Calculate the confusion matrix.\n",
        "  cm = sklearn.metrics.confusion_matrix(test_labels, test_pred)\n",
        "  # Log the confusion matrix as an image summary.\n",
        "  figure = plot_confusion_matrix(cm, class_names=class_names)\n",
        "  cm_image = plot_to_image(figure)\n",
        "\n",
        "  # Log the confusion matrix as an image summary.\n",
        "  with file_writer_cm.as_default():\n",
        "    tf.summary.image(\"Confusion Matrix\", cm_image, step=epoch)\n",
        "\n",
        "# Define the per-epoch callback.\n",
        "cm_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_confusion_matrix)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k6CV7dy-oJZu"
      },
      "outputs": [],
      "source": [
        "# Start TensorBoard.\n",
        "%tensorboard --logdir logs/image\n",
        "\n",
        "# Train the classifier.\n",
        "model.fit(\n",
        "    train_images,\n",
        "    train_labels,\n",
        "    epochs=5,\n",
        "    verbose=0, # Suppress chatty output\n",
        "    callbacks=[tensorboard_callback, cm_callback],\n",
        "    validation_data=(test_images, test_labels),\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o7PnxGf8Ur6F"
      },
      "source": [
        " <img src=\"https://github.com/tensorflow/tensorboard/blob/master/docs/images/images_multiple.png?raw=1\">\n",
        "\n",
        " <img src=\"https://github.com/tensorflow/tensorboard/blob/master/docs/images/images_multiple.png?raw=1\"> "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6URWgszz9Jut"
      },
      "source": [
        "请注意，模型在训练集和验证集上的准确率都在提高。这是一个好迹象。但是，该模型在数据特定子集上的性能如何呢？\n",
        "\n",
        "选择“Images”标签页以可视化记录的混淆矩阵。选中左上方的“Show actual image size”以查看完整尺寸的混淆矩阵。\n",
        "\n",
        "默认情况下，信息中心会显示上次记录的步骤或周期的图像摘要。可以使用滑块查看更早的混淆矩阵。请注意矩阵随训练进行而发生的显著变化：深色的正方形会沿对角线聚集，而矩阵的其余部分则趋于 0 和白色。这意味着您的分类器的性能会随着训练的进行而不断提高！做得好！\n",
        "\n",
        "混淆矩阵表明此简单模型存在一些问题。尽管已取得重大进展，但在衬衫、T 恤和套头衫之间却出现混淆。该模型需要进一步完善。\n",
        "\n",
        "如果您有兴趣，请尝试使用[卷积神经网络](https://medium.com/tensorflow/hello-deep-learning-fashion-mnist-with-keras-50fcff8cd74a) (CNN) 改进此模型。"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "image_summaries.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
